package poolserver

import (
	"bytes"
	"encoding/binary"
	"errors"
	"net"

	//"net/http"

	"crypto/md5"
	"crypto/rand"
	"encoding/hex"

	"go.uber.org/zap"
)

// md5AuthCalc computes the PostgreSQL MD5 password response and writes it to buf:
//
//	"md5" + hex(md5(hex(md5(password + user)) + salt[:4]))
//
// buf must have room for at least 35 bytes ("md5" followed by 32 lowercase hex
// digits); only buf[:35] is written. Only the first 4 bytes of salt are used.
func md5AuthCalc(buf []byte, user string, password string, salt []byte) {
	// First round: md5 over the password immediately followed by the user name.
	// Hash the complete strings (no fixed-size scratch buffer) so that long
	// credentials are never silently truncated before hashing.
	inner := md5.Sum([]byte(password + user))
	var innerHex [2 * md5.Size]byte
	hex.Encode(innerHex[:], inner[:])

	// Second round: md5 over the hex text of the first round plus the salt.
	outer := md5.New()
	outer.Write(innerHex[:])
	outer.Write(salt[:4])

	copy(buf, "md5")
	hex.Encode(buf[3:], outer.Sum(nil))
}

// parseKeyValuePacket parses a packet body laid out as
//
//	int32 length | key | NUL | value | NUL
//
// (for example a ParameterStatus message; presumably the caller has already
// stripped the one-byte message type). The big-endian length at buf[0:4]
// counts itself, so the key starts at offset 4.
//
// It returns the first NUL-terminated string as the key and the second as the
// value. If the key has no NUL terminator inside the packet, both results are
// "". If only the value is unterminated, the value is "".
//
// The declared length is clamped to len(buf). A truncated buffer or a bogus
// length therefore yields empty results instead of an index-out-of-range panic.
func parseKeyValuePacket(buf []byte) (string, string) {
	if len(buf) < 4 {
		return "", ""
	}
	end := len(buf)
	if declared := binary.BigEndian.Uint32(buf[0:4]); uint64(declared) < uint64(end) {
		end = int(declared)
	}
	if end <= 4 {
		return "", ""
	}
	body := buf[4:end]

	k := bytes.IndexByte(body, 0)
	if k < 0 {
		return "", ""
	}
	key := string(body[:k])

	rest := body[k+1:]
	v := bytes.IndexByte(rest, 0)
	if v < 0 {
		return key, ""
	}
	return key, string(rest[:v])
}

// recvStartupPacket reads one start-up-phase packet (e.g. StartupMessage or
// SSLRequest) from conn into buf. Such packets have no leading message-type
// byte: they begin directly with a big-endian int32 length that counts itself.
//
// On success it returns the number of bytes stored in buf, which equals the
// packet's length field. On any error conn is closed and the error is returned.
//
// The length comes from the peer, so it is validated before use. It must be at
// least 8 (the length plus a 4-byte protocol/request code) and must fit in
// buf; otherwise slicing buf with it would panic.
func recvStartupPacket(conn net.Conn, buf []byte) (int32, error) {
	var pos int32

	// Read the 4-byte length prefix.
	for pos < 4 {
		n, err := conn.Read(buf[pos:4])
		if err != nil {
			conn.Close()
			return 0, err
		}
		pos += int32(n)
	}

	dataLen := int32(binary.BigEndian.Uint32(buf[0:4]))
	if dataLen < 8 || int64(dataLen) > int64(len(buf)) {
		zap.S().Infof("Invalid startup packet length from client: %d", dataLen)
		conn.Close()
		return 0, errors.New("invalid startup packet length")
	}

	// Read the rest of the packet.
	for pos < dataLen {
		n, err := conn.Read(buf[pos:dataLen])
		if err != nil {
			zap.S().Infof("Recv from client error(in recvStartupPacket) : %s", err.Error())
			conn.Close()
			return 0, err
		}
		pos += int32(n)
	}
	return pos, nil
}

// parseMulKeyValuePacket parses a sequence of NUL-terminated key/value string
// pairs, such as the parameters of a StartupMessage:
//
//	key1 NUL value1 NUL key2 NUL value2 NUL ... NUL
//
// The list ends at an empty key, i.e. a NUL byte where a key would start.
// Values may be empty: "options\0\0database\0db\0\0" yields options="" and
// database="db".
//
// Only the first dataLen bytes of buf are examined, and dataLen is clamped to
// len(buf). A pair whose key or value lacks its NUL terminator within that
// range is dropped and parsing stops. A non-positive dataLen yields an empty
// map.
func parseMulKeyValuePacket(buf []byte, dataLen int32) map[string]string {
	kvs := make(map[string]string)

	// Never trust dataLen beyond what buf actually holds.
	end := len(buf)
	if int64(dataLen) < int64(end) {
		end = int(dataLen)
	}
	if end <= 0 {
		return kvs
	}

	rest := buf[:end]
	for {
		k := bytes.IndexByte(rest, 0)
		// k == 0: an empty key is the list terminator.
		// k < 0: the key is unterminated, so nothing more can be parsed.
		if k <= 0 {
			return kvs
		}
		key := string(rest[:k])
		rest = rest[k+1:]

		v := bytes.IndexByte(rest, 0)
		if v < 0 {
			// Unterminated value: drop the incomplete pair.
			return kvs
		}
		kvs[key] = string(rest[:v])
		rest = rest[v+1:]
	}
}

func sendPqAuthenticationOk(conn net.Conn) error {
	var buf [9]byte
	buf[0] = 'R'
	binary.BigEndian.PutUint32(buf[1:], 8)
	binary.BigEndian.PutUint32(buf[5:], 0) // 0表示认证成功
	_, err := sendData(conn, buf[0:9])
	return err
}

// sendPqParameterStatus reports a run-time parameter to the client with a
// ParameterStatus message:
//
//	'S' | int32 length | key | NUL | value | NUL
//
// The length covers everything after the type byte, including itself.
func sendPqParameterStatus(conn net.Conn, key string, value string) error {
	bodyLen := 4 + len(key) + 1 + len(value) + 1
	msg := make([]byte, 1+bodyLen) // zero-filled, so the NUL separators are preset
	msg[0] = 'S'
	binary.BigEndian.PutUint32(msg[1:5], uint32(bodyLen))

	off := 5
	off += copy(msg[off:], key)
	off++ // skip the key's NUL terminator
	copy(msg[off:], value)
	// The final byte of msg is the value's NUL terminator.

	_, err := sendData(conn, msg)
	return err
}

// sendPqBackendKeyData sends a BackendKeyData message carrying the process id
// and secret key the client would later use for cancel requests:
//
//	'K' | int32 length (12) | int32 pid | int32 key
func sendPqBackendKeyData(conn net.Conn, pid uint32, key uint32) error {
	msg := make([]byte, 13)
	msg[0] = 'K'
	be := binary.BigEndian
	be.PutUint32(msg[1:5], 12)
	be.PutUint32(msg[5:9], pid)
	be.PutUint32(msg[9:13], key)
	_, err := sendData(conn, msg)
	return err
}

// handleAuth runs the server side of the PostgreSQL start-up handshake on a
// newly accepted client connection:
//
//  1. read the StartupMessage, first answering an SSLRequest with 'N' (no SSL);
//  2. look up the backend pool named "<user>.<database>" and its server version;
//  3. challenge the client with AuthenticationMD5Password (random 4-byte salt)
//     and check its reply against the pool's front-end password
//     (pool.Conf.FePasswd);
//  4. on success send AuthenticationOk, a fixed set of ParameterStatus
//     messages, BackendKeyData (with pid Pid) and ReadyForQuery.
//
// Return values:
//
//	 0, pool  authentication succeeded; pool is the matching backend pool.
//	 1, nil   wrong password; an ErrorResponse was sent, cliConn is left open.
//	-1, nil   I/O or protocol error, or unknown user/database; cliConn is closed.
func handleAuth(cliConn net.Conn, Pid uint32) (int, *Pool) {
	buf := make([]byte, 4096)  // inbound packets
	rbuf := make([]byte, 4096) // outbound scratch and expected MD5 response

	// fail logs the reason, closes the client connection and yields the -1
	// result, so every failure path releases the connection the same way.
	fail := func(format string, args ...interface{}) (int, *Pool) {
		zap.S().Infof(format, args...)
		cliConn.Close()
		return -1, nil
	}

	dataLen, err := recvStartupPacket(cliConn, buf)
	if err != nil {
		return fail("Recv first startup packet from client error: %s", err.Error())
	}

	// Bytes 4..8 hold the protocol version (major, minor); 1234.5679 is the
	// SSLRequest code rather than a real version.
	ver1 := binary.BigEndian.Uint16(buf[4:6])
	ver2 := binary.BigEndian.Uint16(buf[6:8])
	if ver1 == 1234 && ver2 == 5679 {
		// SSL is not supported: refuse with 'N', then read the next start-up packet.
		rbuf[0] = 'N'
		if _, err = cliConn.Write(rbuf[0:1]); err != nil {
			return fail("Reply to client error: %s", err.Error())
		}
		dataLen, err = recvStartupPacket(cliConn, buf)
		if err != nil {
			return fail("Recv second startup packet from client error: %s", err.Error())
		}
	}

	// The key/value pairs follow the 4-byte length and the 4-byte protocol
	// version, so only dataLen-8 bytes of buf[8:] belong to this packet.
	reqOpts := parseMulKeyValuePacket(buf[8:], dataLen-8)

	user, ok := reqOpts["user"]
	if !ok {
		return fail("No 'user' in startup packet!")
	}
	database, ok := reqOpts["database"]
	if !ok {
		return fail("No 'database' in startup packet!")
	}

	poolName := user + "." + database
	var pool *Pool
	pool, ok = g_backend_pool[poolName]
	if !ok {
		return fail("User(%s), database(%s) not in pool!", user, database)
	}

	var dbServerVersion string
	dbServerVersion, ok = g_pools_version[poolName]
	if !ok {
		return fail("User(%s), database(%s), can not database server version!", user, database)
	}

	var randomSalt [4]byte
	if _, err = rand.Read(randomSalt[:]); err != nil {
		return fail("Can not generate random salt: %s", err.Error())
	}

	// AuthenticationMD5Password: 'R', int32 length 12, int32 code 5, 4-byte salt.
	rbuf[0] = 'R'
	binary.BigEndian.PutUint32(rbuf[1:], 12)
	binary.BigEndian.PutUint32(rbuf[5:], 5)
	copy(rbuf[9:], randomSalt[:])
	if _, err = cliConn.Write(rbuf[0:13]); err != nil {
		return fail("Reply to client error: %s", err.Error())
	}

	// Receive the client's salted MD5 response.
	if _, _, err = recvMessage(cliConn, buf); err != nil {
		return fail("Receive from client error: %s", err.Error())
	}
	if buf[0] != 'p' {
		return fail("Expect from client recv password message, but recv %c message", buf[0])
	}

	// PasswordMessage: 'p', int32 length, "md5" + 32 hex digits, NUL.
	// Require the declared length to cover the whole 35-byte response. Otherwise
	// part of buf[5:40] would still hold bytes left over from the startup packet.
	const md5RespLen = 3 + 32
	msgLen := binary.BigEndian.Uint32(buf[1:5])
	md5AuthCalc(rbuf, user, pool.Conf.FePasswd, randomSalt[:])
	if msgLen < 4+md5RespLen || !bytes.Equal(rbuf[:md5RespLen], buf[5:5+md5RespLen]) {
		// ErrorResponse fields: S/V severity, C SQLSTATE (28P01 = invalid_password),
		// M message, F file, L line, R routine.
		errFields := [...]string{
			"SFATAL",
			"VFATAL",
			"C28P01",
			"Mpassword authentication failed",
			"Fcstech.go",
			"L458",
			"Rauth_failed",
		}
		if err = sendPqErrorResponse(cliConn, errFields[:]); err != nil {
			return fail("Reply to client error: %s", err.Error())
		}
		return 1, nil
	}

	if err = sendPqAuthenticationOk(cliConn); err != nil {
		return fail("Reply to client error: %s", err.Error())
	}

	parms := map[string]string{
		"client_encoding":   "UTF8",
		"DateStyle":         "ISO, MDY",
		"integer_datetimes": "on",
		"IntervalStyle":     "postgres",
		"server_encoding":   "UTF8",
		"server_version":    dbServerVersion,
		"TimeZone":          "PRC",
	}
	for key, value := range parms {
		if err = sendPqParameterStatus(cliConn, key, value); err != nil {
			return fail("Reply to client error: %s", err.Error())
		}
	}

	// NOTE(review): the cancel secret key is a fixed constant (232323) rather
	// than a per-session random value; confirm this is intended.
	if err = sendPqBackendKeyData(cliConn, Pid, 232323); err != nil {
		return fail("Reply to client error: %s", err.Error())
	}

	// ReadyForQuery with transaction status 'I' (idle).
	if err = sendPqReadyForQuery(cliConn, 'I'); err != nil {
		return fail("Reply to client error: %s", err.Error())
	}
	return 0, pool
}

// getDbServerVersion asks the database server behind conn for its version. It
// sends the simple query "show server_version;" and reads replies until
// ReadyForQuery ('Z').
//
// The version is taken from the first column of the DataRow ('D') reply. If
// the server answers with an ErrorResponse ('E'), that message is passed to
// printBackendErrorMessage and an error is returned. The error's text is the
// ErrorResponse's 'M' (human-readable message) field.
//
// The unusual (error, string) result order is kept for existing callers.
func getDbServerVersion(conn net.Conn) (error, string) {
	const sql = "show server_version;"

	// Query message: 'Q', int32 length (itself + SQL + NUL), SQL, NUL.
	// make zero-fills, so the final byte is already the NUL terminator.
	bodyLen := 4 + len(sql) + 1
	req := make([]byte, 1+bodyLen)
	req[0] = 'Q'
	binary.BigEndian.PutUint32(req[1:5], uint32(bodyLen))
	copy(req[5:], sql)
	if _, err := sendData(conn, req); err != nil {
		return err, ""
	}

	// errorResponseText returns the 'M' field of the ErrorResponse held in msg.
	// Fields start at offset 5. Each field is a one-byte code followed by a
	// NUL-terminated string, and a zero code byte ends the list.
	errorResponseText := func(msg []byte) string {
		fields := msg[5:]
		for len(fields) > 0 && fields[0] != 0 {
			val := fields[1:]
			end := bytes.IndexByte(val, 0)
			if end < 0 {
				break
			}
			if fields[0] == 'M' {
				return string(val[:end])
			}
			fields = val[end+1:]
		}
		return "backend returned an error without a message"
	}

	recvBuf := make([]byte, 1024)
	var version, errMsg string
	failed := false
	for {
		n, _, err := recvMessage(conn, recvBuf)
		if err != nil {
			return err, ""
		}

		switch recvBuf[0] {
		case 'E':
			failed = true
			printBackendErrorMessage(recvBuf[:n])
			errMsg = errorResponseText(recvBuf)
		case 'D':
			// DataRow: 'D', int32 length, int16 column count, then per column an
			// int32 value length (-1 = NULL) followed by the value bytes.
			// Bounds-check so a NULL or oversized value cannot panic.
			colLen := int32(binary.BigEndian.Uint32(recvBuf[7:11]))
			if colLen >= 0 && int(colLen) <= len(recvBuf)-11 {
				version = string(recvBuf[11 : 11+colLen])
			}
		case 'Z':
			if failed {
				return errors.New(errMsg), ""
			}
			return nil, version
		}
	}
}
